#!/usr/bin/env python

from os.path import join, isdir, isfile
import argparse
import glob
import json
from os import environ
import sys

import boto3
import networkx as nx


# From https://stackoverflow.com/questions/3041986/apt-command-line-interface-like-yes-no-input # noqa
def query_yes_no(question, default="yes"):
    """Ask a yes/no question on the terminal and return the answer.

    Works on both Python 2 (raw_input) and Python 3 (input).

    Args:
        question: string that is presented to the user.
        default: the presumed answer if the user just hits <Enter>.
            It must be "yes" (the default), "no" or None (meaning
            an answer is required of the user).

    Returns:
        True for "yes" or False for "no".

    Raises:
        ValueError: if default is not "yes", "no" or None.
    """
    valid = {"yes": True, "y": True, "ye": True,
             "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    # raw_input only exists on Python 2; on Python 3 the equivalent builtin
    # is input. Python 2's input() evaluates what is typed, so it is only
    # used when raw_input is missing.
    try:
        read_line = raw_input  # noqa
    except NameError:
        read_line = input

    while True:
        sys.stdout.write(question + prompt)
        # Make sure the prompt is visible before blocking on the answer.
        sys.stdout.flush()
        # Ignore surrounding whitespace so that e.g. "y " is accepted.
        choice = read_line().strip().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' "
                             "(or 'y' or 'n').\n")


def read_experiment_files(path):
    """Load every experiment file matching a glob pattern.

    Args:
        path: glob pattern selecting the experiment JSON files to read.

    Returns:
        A pair of dicts, both keyed by each experiment's 'run_name'. The
        first maps to the path of the file the experiment was read from,
        the second to the value get_parent_run_names returns for it.
        If two files share a run name, the one read last wins.
    """
    file_path_by_run = {}
    parents_by_run = {}

    for matched_path in glob.iglob(path):
        with open(matched_path, 'r') as handle:
            experiment = json.load(handle)

        name = experiment['run_name']
        parents = get_parent_run_names(experiment)
        file_path_by_run[name] = matched_path
        parents_by_run[name] = parents

    return file_path_by_run, parents_by_run


def get_parent_run_names(exp):
    """Return the run names that an experiment aggregates, if any.

    Args:
        exp: mapping holding an experiment's configuration.

    Returns:
        The value stored under 'aggregate_run_names', or None when the
        experiment has no such key.
    """
    parent_run_names = exp.get('aggregate_run_names', None)
    return parent_run_names


def build_dep_graph(run_name_to_parent_run_names):
    """Build a directed graph of the dependencies between runs.

    Args:
        run_name_to_parent_run_names: dict mapping each run name to an
            iterable of the run names it depends on, or to None when it
            has no parents.

    Returns:
        An nx.DiGraph with one node per run name and an edge from each
        parent to each of its children.
    """
    graph = nx.DiGraph()

    # Register every run first so that runs without any dependency
    # still show up in the graph.
    graph.add_nodes_from(run_name_to_parent_run_names)

    graph.add_edges_from(
        (parent, child)
        for child, parents in run_name_to_parent_run_names.items()
        if parents is not None
        for parent in parents)

    return graph


def submit_job(branch_name, file_path, tasks, run_name,
               attempts, parent_job_ids=None, cpu=False):
    """Submit one experiment as an AWS Batch job.

    Args:
        branch_name: branch name passed to run_experiment.sh.
        file_path: path of the experiment file. Everything before the
            first 'experiments/' is dropped; a path without that marker
            is passed through unchanged.
        tasks: list of task names appended to the container command.
        run_name: run name of the experiment; used as the job name with
            '/' replaced by '-'.
        attempts: number of attempts given to the job's retry strategy.
        parent_job_ids: optional iterable of Batch job ids this job
            depends on.
        cpu: if True use the 'raster-vision-cpu' queue and job
            definition, otherwise 'raster-vision-gpu'.

    Returns:
        The jobId reported by AWS Batch for the submitted job.

    Raises:
        ValueError: if the S3_BUCKET environment variable is not set.
    """
    s3_bucket = environ.get('S3_BUCKET')
    if not s3_bucket:
        # Fail before talking to AWS: a missing bucket would otherwise put
        # None into the container command.
        raise ValueError(
            'The S3_BUCKET environment variable must be set to submit jobs')

    job_name = run_name.replace('/', '-')

    # Make file_path be rooted at experiments/ instead of src/ since
    # src/ is the root of the directory structure inside the container.
    # str.find returns -1 when the marker is absent; slicing with that
    # would silently reduce the path to its last character, so the path
    # is only trimmed when the marker is actually present.
    marker_index = file_path.find('experiments/')
    if marker_index != -1:
        file_path = file_path[marker_index:]

    command = ['run_experiment.sh', s3_bucket, branch_name, file_path]
    command.extend(tasks)

    depends_on = []
    if parent_job_ids is not None:
        depends_on = [{'jobId': job_id} for job_id in parent_job_ids]

    # The job queue and the job definition share the same name.
    batch_name = 'raster-vision-cpu' if cpu else 'raster-vision-gpu'

    client = boto3.client('batch')
    job_id = client.submit_job(
        jobName=job_name,
        jobQueue=batch_name,
        jobDefinition=batch_name,
        containerOverrides={
            'command': command
        },
        dependsOn=depends_on,
        retryStrategy={
            'attempts': attempts
        })['jobId']

    print(
        'Submitted job with jobName={} and jobId={}'.format(job_name, job_id))

    return job_id


def prompt_user(branch_name, experiment_path, nb_jobs):
    """Ask the user to confirm that the jobs should be submitted.

    The second question is only asked when the first is answered "yes".
    Both questions default to "no".

    Args:
        branch_name: branch name shown in the first question.
        experiment_path: not used by this function.
        nb_jobs: number of jobs shown in the second question.

    Returns:
        A (pushed_branch, want_to_run) tuple of booleans.
    """
    branch_question = (
        'Are your experiment files pushed to the remote branch {}'
        .format(branch_name))
    is_pushed = query_yes_no(branch_question, default='no')
    if not is_pushed:
        return is_pushed, False

    confirm_question = 'Are you sure you want to run {} jobs?'.format(nb_jobs)
    is_confirmed = query_yes_no(confirm_question, default='no')
    return is_pushed, is_confirmed


def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        An argparse.Namespace with the attributes branch_name,
        experiment_path, tasks (a list, possibly empty), attempts (int,
        5 unless given) and cpu (True only when --cpu is passed).
    """
    # Each entry is (positional names/flags, keyword options), registered
    # in order.
    argument_specs = (
        (('branch_name',),
         {'help': 'Branch with code to run on AWS'}),
        (('experiment_path',),
         {'help': 'Directory with experiment files rooted at '
                  'src/experiments/...'}),
        (('tasks',),
         {'nargs': '*',
          'help': 'Space-delimited list of tasks to run on AWS. '
                  'Can be empty to run all tasks.'}),
        (('--attempts',),
         {'type': int, 'default': 5,
          'help': 'Number of times to retry job'}),
        (('--cpu',),
         {'dest': 'cpu', 'action': 'store_true',
          'help': 'Use CPU EC2 instances'}),
    )

    cli = argparse.ArgumentParser()
    for names, options in argument_specs:
        cli.add_argument(*names, **options)
    return cli.parse_args()


def run():
    """Submit the experiments selected on the command line to AWS Batch.

    The experiment path argument may be a directory (every *.json file
    directly inside it is used) or a glob pattern / single file path.
    After the user confirms, a single matching experiment is submitted on
    its own; several experiments are submitted in dependency order, each
    job depending on the Batch jobs of its parent runs.
    """
    args = parse_args()
    print('Branch name: {}'.format(args.branch_name))
    print('Experiment path: {}'.format(args.experiment_path))
    print('Tasks: {}'.format(args.tasks))
    print('Number of attempts: {}'.format(args.attempts))
    print('Using {}'.format('CPU' if args.cpu else 'GPU'))

    experiment_path = join(args.experiment_path, '*.json') \
        if isdir(args.experiment_path) else args.experiment_path
    experiment_paths = list(glob.iglob(experiment_path))
    nb_jobs = len(experiment_paths)

    if nb_jobs == 0:
        # Nothing to submit, so don't ask the user to confirm "0 jobs".
        print('No experiment files found matching {}'.format(experiment_path))
        return

    pushed_branch, want_to_run = prompt_user(
        args.branch_name, experiment_path, nb_jobs)
    if not (pushed_branch and want_to_run):
        return

    if nb_jobs == 1:
        # Use the file that actually matched: experiment_path may be a glob
        # pattern (e.g. a directory holding a single JSON file), which can
        # neither be opened nor handed to the container.
        single_path = experiment_paths[0]
        with open(single_path, 'r') as exp_file:
            run_name = json.load(exp_file)['run_name']
        submit_job(
            args.branch_name, single_path, args.tasks, run_name,
            args.attempts, cpu=args.cpu)
        return

    # Several files: build the dependency graph and submit the runs in
    # topological order, feeding the parent job ids to batch.
    run_name_to_file_path, run_name_to_parent_run_names = \
        read_experiment_files(experiment_path)
    dep_graph = build_dep_graph(run_name_to_parent_run_names)

    run_name_to_job_id = {}
    for run_name in nx.topological_sort(dep_graph):
        file_path = run_name_to_file_path[run_name]

        parent_run_names = run_name_to_parent_run_names[run_name]
        parent_job_ids = None
        if parent_run_names is not None:
            # Parents come earlier in topological order, so their job ids
            # are already recorded.
            parent_job_ids = [run_name_to_job_id[parent_run_name]
                              for parent_run_name in parent_run_names]

        run_name_to_job_id[run_name] = submit_job(
            args.branch_name, file_path, args.tasks, run_name,
            args.attempts, parent_job_ids, cpu=args.cpu)


# Only submit jobs when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    run()
